This notebook's steps:
Create a Spark session
Import training images, as binary data, into a Spark DataFrame
Label images by fruit name, extracted from each image's path
Enhance images by tweaking color, sharpness, contrast and brightness
Extract a 2048-feature array by transfer learning, using the Keras ResNet50 CNN
Store path, label and feature array, partitioned by label, in Parquet format on S3
import pandas as pd
from PIL import Image
from PIL import ImageEnhance
import numpy as np
import io
import os
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, pandas_udf, udf, PandasUDFType, size
from pyspark.sql import types
from pyspark import SparkContext, SparkConf
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.keras.preprocessing.image import img_to_array
# Constants
# Spark master URL: run locally with 2 worker threads
WORKERS = 'local[2]'
# S3 bucket/prefix holding the raw training images (read via the s3a connector)
LOAD_PATH = 's3a://fruits-images-to-proceed/Training_apples/'
# S3 destination for the featurized dataset, written in Parquet format
SAVE_PATH = 's3a://fruits-images-proceded/Training_apples_featured.parquet'
# create a spark session
spark = (SparkSession.builder
.master( WORKERS )
.appName('Feature Training on apples')
# Hadoop-AWS + AWS SDK bundle jars so the s3a:// filesystem is available on the driver
.config('spark.driver.extraClassPath',
'/home/ec2-user/hadoop/share/hadoop/tools/lib/aws-java-sdk-bundle-1.11.375.jar:/home/ec2-user/hadoop/share/hadoop/tools/lib/hadoop-aws-3.2.0.jar')
# Generous heartbeat / network timeouts (in ms) so long featurization tasks are not killed
.config('spark.executor.heartbeatInterval', '300000')
.config('spark.network.timeout', '900000')
# Arrow accelerates Spark <-> pandas conversion (used by pandas UDFs and toPandas)
.config('spark.sql.execution.arrow.pyspark.enabled', 'true')
.getOrCreate()
)
# Read jpg images as binary, recursively
images =(spark
.read
# 'binaryFile' yields one row per file: path, modificationTime, length, content (raw bytes)
.format('binaryFile')
# only pick up JPEG files
.option('pathGlobFilter', '*.jpg')
# walk the whole directory tree under LOAD_PATH
.option('recursiveFileLookup', 'true')
.load(LOAD_PATH)
)
images.printSchema()
root |-- path: string (nullable = true) |-- modificationTime: timestamp (nullable = true) |-- length: long (nullable = true) |-- content: binary (nullable = true)
# Total number of images in the training set
# (fixed misspelled variable name 'totalMunber')
total_number = images.count()
print('Total number of images in train set {}'.format(total_number))
Total number of images in train set 1700
# Offset where the fruit name starts in each image path (right after LOAD_PATH)
label_offset = len(LOAD_PATH)

def extract_label(s):
    """Return the fruit name from an image path.

    The fruit name is the directory component immediately following
    LOAD_PATH, e.g. '.../Training_apples/Apple Red 1/0_100.jpg' -> 'Apple Red 1'.
    """
    last = s[label_offset:]
    return last[:last.rfind('/')]

# UDF created after the function exists; pass the function directly instead of
# wrapping it in a lambda (the original referenced extract_label before it was defined)
col_label = udf(extract_label, types.StringType())

# Create the label column from each image's path
images = images.withColumn('label', col_label(images.path))
images.printSchema()
root |-- path: string (nullable = true) |-- modificationTime: timestamp (nullable = true) |-- length: long (nullable = true) |-- content: binary (nullable = true) |-- label: string (nullable = true)
# Preview the extracted fruit-name labels (untruncated)
images.select('label').show(truncate=False)
+--------------+ |label | +--------------+ |Apple Braeburn| |Apple Braeburn| |Apple Braeburn| |Apple Braeburn| |Apple Braeburn| |Apple Braeburn| |Apple Braeburn| |Apple Braeburn| |Apple Braeburn| |Apple Braeburn| |Apple Braeburn| |Apple Braeburn| |Apple Braeburn| |Apple Red 1 | |Apple Red 1 | |Apple Braeburn| |Apple Red 1 | |Apple Braeburn| |Apple Braeburn| |Apple Braeburn| +--------------+ only showing top 20 rows
# Per-label image counts (class distribution of the training set)
print('By label images count :')
images.groupBy('label').count().show()
By label images count : +---------------+-----+ | label|count| +---------------+-----+ | Apple Golden 3| 481| | Apple Red 2| 492| |Apple Pink Lady| 231| | Apple Red 1| 248| | Apple Braeburn| 248| +---------------+-----+
# Take a ~1% random sample of images for visual inspection
sampleList = images.sample(0.01).select('content').toPandas()
# Decode each raw JPEG byte string and resize to the ResNet50 input size;
# loop variable renamed from 'bin' to 'raw' to avoid shadowing the built-in bin()
imgs = [Image.open(io.BytesIO(raw)).resize([224, 224])
        for raw in sampleList.content]
# Display each sample
for i, img in enumerate(imgs):
    print('Sample_' + str(i))
    img.show()
Sample_0
Sample_1
Sample_2
Sample_3
Sample_4
Sample_5
Sample_6
Sample_7
Sample_8
Sample_9
Sample_10
Sample_11
Sample_12
# Preview the Color enhancement on the first sample, sweeping the factor
# from 0.0 to 4.875 in steps of 1/8
color_enh = ImageEnhance.Color(imgs[0])
for step in range(40):
    factor = step * 0.125
    print(f'Color {factor:f}')
    color_enh.enhance(factor).show()
Color 0.000000
Color 0.125000
Color 0.250000
Color 0.375000
Color 0.500000
Color 0.625000
Color 0.750000
Color 0.875000
Color 1.000000
Color 1.125000
Color 1.250000
Color 1.375000
Color 1.500000
Color 1.625000
Color 1.750000
Color 1.875000
Color 2.000000
Color 2.125000
Color 2.250000
Color 2.375000
Color 2.500000
Color 2.625000
Color 2.750000
Color 2.875000
Color 3.000000
Color 3.125000
Color 3.250000
Color 3.375000
Color 3.500000
Color 3.625000
Color 3.750000
Color 3.875000
Color 4.000000
Color 4.125000
Color 4.250000
Color 4.375000
Color 4.500000
Color 4.625000
Color 4.750000
Color 4.875000
# Preview the Sharpness enhancement on the first sample, sweeping the factor
# from 0.0 to 4.875 in steps of 1/8
sharp_enh = ImageEnhance.Sharpness(imgs[0])
for step in range(40):
    factor = step * 0.125
    print(f'Sharpness {factor:f}')
    sharp_enh.enhance(factor).show()
Sharpness 0.000000
Sharpness 0.125000
Sharpness 0.250000
Sharpness 0.375000
Sharpness 0.500000
Sharpness 0.625000
Sharpness 0.750000
Sharpness 0.875000
Sharpness 1.000000
Sharpness 1.125000
Sharpness 1.250000
Sharpness 1.375000
Sharpness 1.500000
Sharpness 1.625000
Sharpness 1.750000
Sharpness 1.875000
Sharpness 2.000000
Sharpness 2.125000
Sharpness 2.250000
Sharpness 2.375000
Sharpness 2.500000
Sharpness 2.625000
Sharpness 2.750000
Sharpness 2.875000
Sharpness 3.000000
Sharpness 3.125000
Sharpness 3.250000
Sharpness 3.375000
Sharpness 3.500000
Sharpness 3.625000
Sharpness 3.750000
Sharpness 3.875000
Sharpness 4.000000
Sharpness 4.125000
Sharpness 4.250000
Sharpness 4.375000
Sharpness 4.500000
Sharpness 4.625000
Sharpness 4.750000
Sharpness 4.875000
# Preview the Contrast enhancement on the first sample, sweeping the factor
# from 0.0 to 4.875 in steps of 1/8
contrast_enh = ImageEnhance.Contrast(imgs[0])
for step in range(40):
    factor = step * 0.125
    print(f'Contrast {factor:f}')
    contrast_enh.enhance(factor).show()
Contrast 0.000000
Contrast 0.125000
Contrast 0.250000
Contrast 0.375000
Contrast 0.500000
Contrast 0.625000
Contrast 0.750000
Contrast 0.875000
Contrast 1.000000
Contrast 1.125000
Contrast 1.250000
Contrast 1.375000
Contrast 1.500000
Contrast 1.625000
Contrast 1.750000
Contrast 1.875000
Contrast 2.000000
Contrast 2.125000
Contrast 2.250000
Contrast 2.375000
Contrast 2.500000
Contrast 2.625000
Contrast 2.750000
Contrast 2.875000
Contrast 3.000000
Contrast 3.125000
Contrast 3.250000
Contrast 3.375000
Contrast 3.500000
Contrast 3.625000
Contrast 3.750000
Contrast 3.875000
Contrast 4.000000
Contrast 4.125000
Contrast 4.250000
Contrast 4.375000
Contrast 4.500000
Contrast 4.625000
Contrast 4.750000
Contrast 4.875000
# Preview the Brightness enhancement on the first sample, sweeping the factor
# from 0.0 to 4.875 in steps of 1/8
# (fixed misspelled local name 'brigthnessEnhancer')
brightnessEnhancer = ImageEnhance.Brightness(imgs[0])
for i in range(40):
    factor = i / 8
    print(f'Brightness {factor:f}')
    brightnessEnhancer.enhance(factor).show()
Brightness 0.000000
Brightness 0.125000
Brightness 0.250000
Brightness 0.375000
Brightness 0.500000
Brightness 0.625000
Brightness 0.750000
Brightness 0.875000
Brightness 1.000000
Brightness 1.125000
Brightness 1.250000
Brightness 1.375000
Brightness 1.500000
Brightness 1.625000
Brightness 1.750000
Brightness 1.875000
Brightness 2.000000
Brightness 2.125000
Brightness 2.250000
Brightness 2.375000
Brightness 2.500000
Brightness 2.625000
Brightness 2.750000
Brightness 2.875000
Brightness 3.000000
Brightness 3.125000
Brightness 3.250000
Brightness 3.375000
Brightness 3.500000
Brightness 3.625000
Brightness 3.750000
Brightness 3.875000
Brightness 4.000000
Brightness 4.125000
Brightness 4.250000
Brightness 4.375000
Brightness 4.500000
Brightness 4.625000
Brightness 4.750000
Brightness 4.875000
# Enhance image
def enhance(img,
            color=1.25,
            sharpness=4.5,
            contrast=1.25,
            brigthness=1.5):
    """Enhance a PIL image by boosting color, sharpness, contrast and brightness.

    Each factor is a multiplier where 1.0 leaves the image unchanged.
    The parameter name 'brigthness' (sic) is kept for backward compatibility
    with keyword-argument callers.

    BUG FIX: the original applied only the color enhancement — the sharpness,
    contrast and brightness enhancer results were computed but discarded
    (never reassigned to img). Each stage now feeds the next.
    """
    img = ImageEnhance.Color(img).enhance(color)
    img = ImageEnhance.Sharpness(img).enhance(sharpness)
    img = ImageEnhance.Contrast(img).enhance(contrast)
    img = ImageEnhance.Brightness(img).enhance(brigthness)
    return img
# Show every sample next to its enhanced version
for idx, sample in enumerate(imgs):
    print('Sample_{} original'.format(str(idx)))
    display(sample)
    print('Sample_{} enhanced'.format(str(idx)))
    display(enhance(sample))
Sample_0 original
Sample_0 enhanced
Sample_1 original
Sample_1 enhanced
Sample_2 original
Sample_2 enhanced
Sample_3 original
Sample_3 enhanced
Sample_4 original
Sample_4 enhanced
Sample_5 original
Sample_5 enhanced
Sample_6 original
Sample_6 enhanced
Sample_7 original
Sample_7 enhanced
Sample_8 original
Sample_8 enhanced
Sample_9 original
Sample_9 enhanced
Sample_10 original
Sample_10 enhanced
Sample_11 original
Sample_11 enhanced
Sample_12 original
Sample_12 enhanced
def model_fn():
    """Build a headless ResNet50 for feature extraction.

    Loads the full pretrained ResNet50, then returns a Model whose output is
    the penultimate layer (the 2048-unit pooling layer), dropping the final
    classification head.
    """
    base = ResNet50()
    return Model(inputs=base.inputs, outputs=base.layers[-2].output)
def preprocess(content):
    """Turn raw JPEG bytes into a ResNet50-ready input array.

    Decodes the bytes, resizes to the 224x224 ResNet50 input size, applies the
    enhance() pipeline, converts to an array and runs Keras' ResNet50
    preprocessing.
    """
    image = Image.open(io.BytesIO(content)).resize([224, 224])
    image = enhance(image)
    return preprocess_input(img_to_array(image))
def featurize_series(model, content_series):
    """Featurize a pd.Series of raw image bytes with the given model.

    :param model: headless ResNet50 (see model_fn)
    :param content_series: pd.Series of raw JPEG byte strings
    :return: pd.Series of flattened per-image feature vectors

    Fixes: renamed 'input' (shadowed the built-in) to 'batch'; removed the
    commented-out dead np.stack line.
    """
    # Stack the preprocessed images into one batch tensor
    batch = tf.convert_to_tensor(np.stack(content_series.map(preprocess)),
                                 dtype=tf.float32)
    preds = model.predict(batch)
    # For some layers, output features are multi-dimensional tensors;
    # flatten to vectors for easier storage in Spark DataFrames.
    return pd.Series([p.flatten() for p in preds])
from typing import Iterator
@pandas_udf('array<float>')
def featurize_udf(content_series_iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
    """Scalar Iterator pandas UDF wrapping the featurization function.

    The decorator makes this produce a Spark column of type array<float>.
    The iterator form lets the heavyweight model be loaded once per executor
    process and reused across every incoming batch of image data, amortizing
    the model-loading overhead.
    """
    model = model_fn()
    for batch in content_series_iter:
        yield featurize_series(model, batch)
Apply featurization to the DataFrame of images
# Avoiding Out Of Memory (OOM) errors by reducing the Arrow batch size
spark.conf.set('spark.sql.execution.arrow.maxRecordsPerBatch', '128')
# Transfer learning: add a 'features' column computed by the ResNet50 pandas UDF
images = images.withColumn('features', featurize_udf(images.content))
images.printSchema()
root |-- path: string (nullable = true) |-- modificationTime: timestamp (nullable = true) |-- length: long (nullable = true) |-- content: binary (nullable = true) |-- label: string (nullable = true) |-- features: array (nullable = true) | |-- element: float (containsNull = true)
# Save path, label and features, partitioned by label, as Parquet in the S3 bucket
(images
.select('path','label','features')
.write
.partitionBy('label')
.mode('overwrite')
.parquet(SAVE_PATH)
)
# Release cluster resources
spark.stop()